//
// Copyright (c) 2014 Wirebird Labs LLC.  All rights reserved.
// Copyright 2020 Staysail Systems, Inc. <info@staysail.tech>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom
// the Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
// IN THE SOFTWARE.
//

#include "base64.h"

#include <ctype.h>

// nni_base64_decode decodes base64 text from in (in_len characters, not
// required to be NUL terminated) into the buffer out, which can hold
// out_len bytes.
//
// Whitespace characters (as classified by isspace) are skipped.  Decoding
// stops at the first '=' padding character, or at the first character that
// is neither whitespace nor part of the base64 alphabet; anything after
// that point is ignored.  Trailing bits that do not make up a whole byte
// are dropped.
//
// Returns the number of bytes stored in out, or (size_t) -1 if out is too
// small to hold the decoded data (out may have been partially written).
size_t
nni_base64_decode(const char *in, size_t in_len, uint8_t *out, size_t out_len)
{
	size_t   ii;
	size_t   io  = 0;
	unsigned rem = 0; // number of decoded bits pending in v
	uint32_t v   = 0; // bit accumulator; only the low bits are used

	// Lookup of ASCII code points to 6-bit values.
	// 0xFF represents a character outside the base64 alphabet.
	// (The entry for '=' is never consulted, because padding is
	// handled before the table lookup.)
	static const uint8_t decode[256] = {
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF,
		0x3F, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C,
		0x3D, 0xFF, 0xFF, 0xFF, 0x3E, 0xFF, 0xFF, 0xFF, 0x00, 0x01,
		0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B,
		0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
		0x16, 0x17, 0x18, 0x19, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23,
		0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D,
		0x2E, 0x2F, 0x30, 0x31, 0x32, 0x33, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
	};

	for (ii = 0; ii < in_len; ii++) {
		// Plain char may be signed; go through unsigned char so that
		// bytes >= 0x80 are valid arguments to isspace() and valid
		// (non-negative) indices into the lookup table.
		unsigned char c = (unsigned char) in[ii];
		uint8_t       ch;

		if (isspace(c)) {
			continue;
		}

		// Padding marks the end of the encoded data.
		if (c == '=') {
			break;
		}

		ch = decode[c];

		// Stop at the first character outside the base64 alphabet.
		if (ch == 0xFF) {
			break;
		}

		v = (v << 6u) | ch;
		rem += 6;

		// Emit a byte as soon as eight bits are available.  After
		// this step rem is always below 8, so no flush is needed
		// once the loop ends.
		if (rem >= 8) {
			rem -= 8;
			if (io >= out_len) {
				return ((size_t) -1);
			}
			out[io++] = (uint8_t) ((v >> rem) & 255u);
		}
	}

	return (io);
}

// nni_base64_encode encodes in_len bytes from in as base64 text into out,
// which can hold out_len characters.  The output is padded with '=' to a
// multiple of four characters and is always NUL terminated on success, so
// out_len must be at least the encoded length plus one.
//
// Returns the number of characters written, not counting the terminating
// NUL, or (size_t) -1 if out is too small (out may have been partially
// written and is not NUL terminated in that case).
size_t
nni_base64_encode(const uint8_t *in, size_t in_len, char *out, size_t out_len)
{
	size_t   ii;
	size_t   io  = 0;
	unsigned rem = 0; // number of input bits pending in v
	uint32_t v   = 0; // bit accumulator; only the low bits are used

	static const char encode[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
	                             "abcdefghijklmnopqrstuvwxyz"
	                             "0123456789+/";

	for (ii = 0; ii < in_len; ii++) {
		v = (v << 8u) | in[ii];
		rem += 8;
		// Emit one character for every complete group of six bits.
		while (rem >= 6) {
			rem -= 6;
			if (io >= out_len) {
				return ((size_t) -1);
			}
			out[io++] = encode[(v >> rem) & 63u];
		}
	}

	// Left-align any leftover bits (2 or 4 of them) in a final sextet.
	if (rem != 0) {
		v <<= (6 - rem);
		if (io >= out_len) {
			return ((size_t) -1);
		}
		out[io++] = encode[v & 63u];
	}

	// Pad the output to a multiple of 4 characters.
	while ((io & 3u) != 0) {
		if (io >= out_len) {
			return ((size_t) -1);
		}
		out[io++] = '=';
	}

	// Room is also required for the terminating NUL.
	if (io >= out_len) {
		return ((size_t) -1);
	}
	out[io] = '\0';

	return (io);
}
